from .maml import Maml
from .maml_metric import MamlAcc
from .maml_loss import MamlLoss
from lib.sampler import EpisodeSampler
from lib.algorithms import ALGORITHMS


# note: MAML不能开半精度AMP, 否则在maml.py中会有一些莫名其妙的bug,
# 最要命的就是net和self.nets[type]的网络参数明明一模一样, 但是输入相同的数据,
# 得到的logits就是天差地别, 无语........

maml_config = {
    'update_lr': 1e-2,
    'update_step': 10
}


@ALGORITHMS.register('maml')
def maml(cfg):
    return {
        'net': Maml(cfg,maml_config),
        'loss': MamlLoss(cfg),
        'metric': MamlAcc(cfg),
        'train_sampler': EpisodeSampler,
        'validate_sampler': EpisodeSampler,
        'test_sampler': EpisodeSampler
    }
